"""Gaussian Mixture Policy."""
from collections import OrderedDict

import numpy as np
import tensorflow as tf
import tensorflow_probability as tfp

from softlearning.models.feedforward import feedforward_model
from softlearning.models.utils import flatten_input_structure, create_inputs
from softlearning.utils.tensorflow import nest
from .base_policy import LatentSpacePolicy



# Generator

# class HyperCritic():
class GraphicalPolicy(LatentSpacePolicy):
    """Latent-variable policy with a discrete component and a continuous code.

    For a batch of conditions (the concatenated, preprocessed inputs), the
    graph built in ``__init__`` does the following:

    1. It samples a one-hot component ``k`` with a straight-through
       Gumbel-softmax (temperature 1) over the logits produced by
       ``_condition_state_net``.
    2. It maps ``k`` through ``_trainable_shift_net`` and adds standard-normal
       noise of size ``latentVdim``, giving the latent code ``hyper_p_z``.
    3. It feeds ``[conditions, hyper_p_z]`` to
       ``_shift_and_log_scale_diag_net`` to obtain the shift and log-scale of
       a diagonal Gaussian over raw actions.
    4. It squashes raw actions with ``tanh`` (``squash=True``) or leaves them
       unchanged (``squash=False``).

    NOTE(review): ``k`` and ``hyper_p_z`` are sampled inside the graph, so
    every model built here resamples them on each call. This includes
    ``deterministic_actions_model`` (whose output depends on ``shift``) and
    ``actions_model_for_fixed_latents``. ``log_pis`` is the log-density
    conditioned on the sampled ``k``/``hyper_p_z``, not marginalised over
    them.

    Subclasses must implement ``_shift_and_log_scale_diag_net``,
    ``_trainable_shift_net`` and ``_condition_state_net``.
    """

    def __init__(self,
                 input_shapes,
                 output_shape,
                 action_range,
                 *args,
                 squash=True,
                 preprocessors=None,
                 name=None,
                 num_mixture=4,
                 latentVdim=10,
                 **kwargs):
        """Build the sampling, log-prob and diagnostics Keras models.

        Args:
            input_shapes: Structure of input shapes, passed to
                ``create_inputs``.
            output_shape: Shape of the action vector.
            action_range: Must equal ``[[-1], [1]]`` (asserted below).
            *args: Forwarded to ``LatentSpacePolicy.__init__``.
            squash: If True, raw actions go through a ``Tanh`` bijector;
                otherwise they go through ``Identity``.
            preprocessors: Optional structure parallel to the inputs. It is
                flattened with ``flatten_input_structure``. A ``None`` entry
                passes the corresponding input through unchanged.
            name: Stored as ``self._name``.
            num_mixture: Number of discrete components, i.e. the width of the
                Gumbel-softmax sample.
            latentVdim: Dimensionality of the continuous latent code.
            **kwargs: Forwarded to ``LatentSpacePolicy.__init__``.
        """
        assert (np.all(action_range == np.array([[-1], [1]]))), (
            "The action space should be scaled to (-1, 1)."
            " TODO(hartikainen): We should support non-scaled actions spaces.")

        self._Serializable__initialize(locals())

        self._action_range = action_range
        self._input_shapes = input_shapes
        self._output_shape = output_shape
        self._squash = squash
        self._name = name
        self._num_mixture = num_mixture
        self._latent_dim = latentVdim

        # Uniform prior over the mixture components.
        # NOTE(review): _prior_k is not referenced elsewhere in this class.
        self._PI = tf.constant(np.asarray(
            [1. / self._num_mixture] * self._num_mixture, dtype=np.float32))
        self._prior_k = tf.distributions.Categorical(probs=self._PI)

        super(GraphicalPolicy, self).__init__(*args, **kwargs)
        inputs_flat = create_inputs(input_shapes)
        preprocessors_flat = (
            flatten_input_structure(preprocessors)
            if preprocessors is not None
            else tuple(None for _ in inputs_flat))

        assert len(inputs_flat) == len(preprocessors_flat), (
            inputs_flat, preprocessors_flat)

        preprocessed_inputs = [
            preprocessor(input_) if preprocessor is not None else input_
            for preprocessor, input_
            in zip(preprocessors_flat, inputs_flat)
        ]

        def cast_and_concat(x):
            """Cast every input to float32 and concatenate on the last axis."""
            x = nest.map_structure(
                lambda element: tf.cast(element, tf.float32), x)
            x = nest.flatten(x)
            x = tf.concat(x, axis=-1)
            return x

        def sample_gumbel(shape, eps=1e-20):
            """Sample from Gumbel(0, 1) by inverse CDF; ``eps`` guards log(0)."""
            U = tf.random_uniform(shape, minval=0, maxval=1)
            return -tf.log(-tf.log(U + eps) + eps)

        def gumbel_softmax_sample(logits, temperature):
            """Draw a relaxed (soft) sample from the Gumbel-Softmax distribution."""
            y = logits + sample_gumbel(tf.shape(logits))
            return tf.nn.softmax(y / temperature)

        def gumbel_softmax(logits, temperature, hard=True):
            """Sample from the Gumbel-Softmax distribution and optionally discretize.

            Args:
                logits: ``[batch_size, n_class]`` unnormalized log-probs.
                temperature: Non-negative scalar.
                hard: If True, return the one-hot argmax in the forward pass
                    but differentiate with respect to the soft sample
                    (straight-through estimator).

            Returns:
                ``[batch_size, n_class]`` sample. It is one-hot if ``hard``
                (exact ties would mark several classes), otherwise a
                probability vector that sums to 1 across classes.
            """
            y = gumbel_softmax_sample(logits, temperature)
            if hard:
                # Reduce over the class axis (the last one, as for softmax).
                y_hard = tf.cast(
                    tf.equal(y, tf.reduce_max(y, axis=-1, keepdims=True)),
                    y.dtype)
                # Forward value is y_hard; the gradient flows through y.
                y = tf.stop_gradient(y_hard - y) + y
            return y

        conditions = tf.keras.layers.Lambda(
            cast_and_concat
        )(preprocessed_inputs)

        self.condition_inputs = inputs_flat

        batch_size = tf.keras.layers.Lambda(
            lambda x: tf.shape(input=x)[0])(conditions)

        # One-hot mixture component per batch row.
        # NOTE(review): applied to a Keras tensor without a Lambda wrapper;
        # this relies on TF auto-wrapping raw ops into the functional graph.
        self._hyper_p_k = gumbel_softmax(
            self._condition_state_net()(conditions),
            temperature=1.0,
            hard=True)

        base_distribution = tfp.distributions.MultivariateNormalDiag(
            loc=tf.zeros(output_shape),
            scale_diag=tf.ones(output_shape))

        base_distribution_latent = tfp.distributions.MultivariateNormalDiag(
            loc=tf.zeros(self._latent_dim),
            scale_diag=tf.ones(self._latent_dim))

        # Standard-normal action noise, one row per batch element.
        latents = tf.keras.layers.Lambda(
            lambda batch_size: base_distribution.sample(batch_size)
        )(batch_size)

        # Standard-normal noise added to the component-specific latent shift.
        epi_latent = tf.keras.layers.Lambda(
            lambda batch_size: base_distribution_latent.sample(batch_size)
        )(batch_size)

        self.hyper_p_z = tf.add(
            self._trainable_shift_net()(tf.cast(self._hyper_p_k, tf.float32)),
            epi_latent)

        input_for_shiftmodel = tf.concat([conditions, self.hyper_p_z], axis=-1)

        shift_and_log_scale_diag = self._shift_and_log_scale_diag_net(
            output_size=np.prod(output_shape) * 2,
        )(input_for_shiftmodel)
        shift, log_scale_diag = tf.keras.layers.Lambda(
            lambda shift_and_log_scale_diag: tf.split(
                shift_and_log_scale_diag,
                num_or_size_splits=2,
                axis=-1)
        )(shift_and_log_scale_diag)

        self.latents_model = tf.keras.Model(self.condition_inputs, latents)

        self.latents_input = tf.keras.layers.Input(
            shape=output_shape, name='latents')

        def raw_actions_fn(inputs):
            """Map latents through the affine bijector: shift + exp(log_scale) * z."""
            shift, log_scale_diag, latents = inputs
            bijector = tfp.bijectors.Affine(
                shift=shift,
                scale_diag=tf.exp(log_scale_diag))
            actions = bijector.forward(latents)
            return actions

        raw_actions = tf.keras.layers.Lambda(
            raw_actions_fn
        )((shift, log_scale_diag, latents))

        raw_actions_for_fixed_latents = tf.keras.layers.Lambda(
            raw_actions_fn
        )((shift, log_scale_diag, self.latents_input))

        squash_bijector = (
            tfp.bijectors.Tanh()
            if self._squash
            else tfp.bijectors.Identity())

        actions = tf.keras.layers.Lambda(
            lambda raw_actions: squash_bijector.forward(raw_actions)
        )(raw_actions)
        self.actions_model = tf.keras.Model(self.condition_inputs, actions)

        actions_for_fixed_latents = tf.keras.layers.Lambda(
            lambda raw_actions: squash_bijector.forward(raw_actions)
        )(raw_actions_for_fixed_latents)
        self.actions_model_for_fixed_latents = tf.keras.Model(
            (*self.condition_inputs, self.latents_input),
            actions_for_fixed_latents)

        deterministic_actions = tf.keras.layers.Lambda(
            lambda shift: squash_bijector.forward(shift)
        )(shift)

        self.deterministic_actions_model = tf.keras.Model(
            self.condition_inputs, deterministic_actions)

        def log_pis_fn(inputs):
            """Log-density of ``actions`` under squash(Affine(N(0, I))).

            Returns a ``[batch, 1]`` tensor. The density is conditioned on the
            ``shift``/``log_scale_diag`` given, i.e. on the sampled component
            and latent code.
            """
            shift, log_scale_diag, actions = inputs
            base_distribution = tfp.distributions.MultivariateNormalDiag(
                loc=tf.zeros(output_shape),
                scale_diag=tf.ones(output_shape))
            bijector = tfp.bijectors.Chain((
                squash_bijector,
                tfp.bijectors.Affine(
                    shift=shift,
                    scale_diag=tf.exp(log_scale_diag)),
            ))
            distribution = (
                tfp.distributions.TransformedDistribution(
                    distribution=base_distribution,
                    bijector=bijector))

            log_pis = distribution.log_prob(actions)[:, None]
            return log_pis

        # p(a | s, k, z) for externally supplied actions.
        self.actions_input = tf.keras.layers.Input(
            shape=output_shape, name='actions')

        log_pis = tf.keras.layers.Lambda(
            log_pis_fn)([shift, log_scale_diag, actions])

        log_pis_for_action_input = tf.keras.layers.Lambda(
            log_pis_fn)([shift, log_scale_diag, self.actions_input])

        self.log_pis_model = tf.keras.Model(
            (*self.condition_inputs, self.actions_input),
            log_pis_for_action_input)

        self.diagnostics_model = tf.keras.Model(
            self.condition_inputs,
            (shift, log_scale_diag, log_pis, raw_actions, actions))

    def _shift_and_log_scale_diag_net(self, output_size):
        """Return a model mapping ``[conditions, hyper_p_z]`` to ``output_size`` values.

        The output is split in half into the shift and the log-scale diagonal.
        It is called with the keyword ``output_size`` only, so overrides must
        accept exactly that argument.
        """
        raise NotImplementedError

    def _trainable_shift_net(self):
        """Return a model mapping the one-hot component to a ``latent_dim`` shift."""
        raise NotImplementedError

    def __hyper_generator(self):
        """Unused placeholder (name-mangled; not called in this class)."""
        raise NotImplementedError

    def _condition_state_net(self):
        """Return a model mapping conditions to ``num_mixture`` component logits."""
        raise NotImplementedError

    def get_weights(self):
        """Return the weights of ``actions_model``."""
        return self.actions_model.get_weights()

    def set_weights(self, *args, **kwargs):
        """Set the weights of ``actions_model`` (forwards all arguments)."""
        return self.actions_model.set_weights(*args, **kwargs)

    @property
    def trainable_variables(self):
        """Trainable variables of ``actions_model``."""
        return self.actions_model.trainable_variables

    def get_diagnostics(self, inputs):
        """Return diagnostic statistics of the policy for a batch of ``inputs``.

        Reports the mean and standard deviation of the shifts, log-scale
        diagonals, negative log-probabilities and raw (pre-squash) actions,
        plus the mean, standard deviation, min and max of the actions.
        """
        (shifts_np,
         log_scale_diags_np,
         log_pis_np,
         raw_actions_np,
         actions_np) = self.diagnostics_model.predict(inputs)

        return OrderedDict((
            ('shifts-mean', np.mean(shifts_np)),
            ('shifts-std', np.std(shifts_np)),

            ('log_scale_diags-mean', np.mean(log_scale_diags_np)),
            ('log_scale_diags-std', np.std(log_scale_diags_np)),

            ('-log-pis-mean', np.mean(-log_pis_np)),
            ('-log-pis-std', np.std(-log_pis_np)),

            ('raw-actions-mean', np.mean(raw_actions_np)),
            ('raw-actions-std', np.std(raw_actions_np)),

            ('actions-mean', np.mean(actions_np)),
            ('actions-std', np.std(actions_np)),
            ('actions-min', np.min(actions_np)),
            ('actions-max', np.max(actions_np)),
        ))


# Hyper Generator

class FeedforwardGraphicalPolicy(GraphicalPolicy):
    """``GraphicalPolicy`` whose sub-networks are all built by ``feedforward_model``."""

    def __init__(self,
                 hidden_layer_sizes,
                 activation='relu',
                 output_activation='linear',
                 *args,
                 **kwargs):
        """Record the MLP hyperparameters, then build the policy graph.

        The attributes are assigned before delegating to the base constructor
        because that constructor calls the ``_*_net`` factories below, which
        read them.

        Args:
            hidden_layer_sizes: Hidden sizes shared by the shift/log-scale and
                condition networks.
            activation: Hidden-layer activation for those networks.
            output_activation: Output activation for those networks.
            *args, **kwargs: Forwarded to ``GraphicalPolicy.__init__``.
        """
        self._output_activation = output_activation
        self._activation = activation
        self._hidden_layer_sizes = hidden_layer_sizes

        self._Serializable__initialize(locals())
        super(FeedforwardGraphicalPolicy, self).__init__(*args, **kwargs)

    def _shift_and_log_scale_diag_net(self, output_size):
        """MLP emitting ``output_size`` values (shift and log-scale, concatenated)."""
        return feedforward_model(
            output_size=output_size,
            hidden_layer_sizes=self._hidden_layer_sizes,
            output_activation=self._output_activation,
            activation=self._activation)

    def _condition_state_net(self):
        """MLP named ``condition_gen`` emitting ``num_mixture`` component logits."""
        return feedforward_model(
            name='condition_gen',
            output_size=self._num_mixture,
            hidden_layer_sizes=self._hidden_layer_sizes,
            output_activation=self._output_activation,
            activation=self._activation)

    def _trainable_shift_net(self):
        """Model named ``hyper_generator`` with no hidden layers and linear activations.

        It maps the one-hot component to a ``latent_dim``-sized shift.
        """
        return feedforward_model(
            name='hyper_generator',
            output_size=self._latent_dim,
            hidden_layer_sizes=(),
            output_activation='linear',
            activation='linear')
    #


#